{ "cells": [ { "cell_type": "markdown", "id": "yfiI-lSZRuOP", "metadata": { "id": "yfiI-lSZRuOP" }, "source": [ "### **1. S-learner**\n", "\n", "\n", "The first estimator we would like to introduce is the S-learner, also known as a ``single learner\". This is one of the most foundamental learners in HTE esitmation, and is very easy to implement.\n", "\n", "Under three common assumptions in causal inference, i.e. (1) consistency, (2) no unmeasured confounders (NUC), (3) positivity assumption, the heterogeneous treatment effect can be identified by the observed data, where\n", "\\begin{equation*}\n", "\\tau(s)=\\mathbb{E}[R|S,A=1]-\\mathbb{E}[R|S,A=0].\n", "\\end{equation*}\n", "\n", "The basic idea of S-learner is to fit a model for $\\mathbb{E}[R|S,A]$, and then construct a plug-in estimator for it. Specifically, the algorithm can be summarized as below:\n", "\n", "**Step 1:** Estimate the response function $\\mu(s,a):=\\mathbb{E}[R|S=s,A=a]$ with any supervised machine learning algorithm;\n", "\n", "**Step 2:** The estimated HTE of S-learner is given by \n", "\\begin{equation*}\n", "\\hat{\\tau}_{\\text{S-learner}}(s)=\\hat\\mu(s,1)-\\hat\\mu(s,0).\n", "\\end{equation*}\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 1, "id": "eRpP5k9MBtzO", "metadata": { "ExecuteTime": { "end_time": "2023-11-12T12:59:55.276902Z", "start_time": "2023-11-12T12:59:54.456667Z" }, "id": "eRpP5k9MBtzO" }, "outputs": [], "source": [ "# import related packages\n", "import numpy as np\n", "import pandas as pd\n", "from matplotlib import pyplot as plt;\n", "from sklearn.ensemble import GradientBoostingRegressor\n", "from sklearn.linear_model import LinearRegression\n", "from causaldm.learners.CEL.Single_Stage import _env_getdata_CEL" ] }, { "cell_type": "markdown", "id": "XUu695Qrf61-", "metadata": { "id": "XUu695Qrf61-" }, "source": [ "### MovieLens Data" ] }, { "cell_type": "code", "execution_count": 2, "id": "JhfJntzcVVy2", "metadata": { "ExecuteTime": { "end_time": 
"2023-11-12T12:59:55.353538Z", "start_time": "2023-11-12T12:59:55.278183Z" }, "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "executionInfo": { "elapsed": 288, "status": "ok", "timestamp": 1676750101543, "user": { "displayName": "Yang Xu", "userId": "12270366590264264299" }, "user_tz": 300 }, "id": "JhfJntzcVVy2", "outputId": "7fab8a7a-7cd9-445c-a005-9a6d1994a071" }, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr style=\"text-align: right;\"><th></th><th>user_id</th><th>movie_id</th><th>rating</th><th>age</th><th>Drama</th><th>gender_M</th><th>occupation_academic/educator</th><th>occupation_college/grad student</th><th>occupation_executive/managerial</th><th>occupation_other</th><th>occupation_technician/engineer</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>48.0</td><td>1193.0</td><td>4.0</td><td>25.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>1</th><td>48.0</td><td>919.0</td><td>4.0</td><td>25.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>2</th><td>48.0</td><td>527.0</td><td>5.0</td><td>25.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>3</th><td>48.0</td><td>1721.0</td><td>4.0</td><td>25.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>4</th><td>48.0</td><td>150.0</td><td>4.0</td><td>25.0</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>0.0</td><td>0.0</td><td>0.0</td></tr>\n",
"    <tr><th>...</th><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td><td>...</td></tr>\n",
"    <tr><th>65637</th><td>5878.0</td><td>3300.0</td><td>2.0</td><td>25.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"    <tr><th>65638</th><td>5878.0</td><td>1391.0</td><td>1.0</td><td>25.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"    <tr><th>65639</th><td>5878.0</td><td>185.0</td><td>4.0</td><td>25.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"    <tr><th>65640</th><td>5878.0</td><td>2232.0</td><td>1.0</td><td>25.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"    <tr><th>65641</th><td>5878.0</td><td>426.0</td><td>3.0</td><td>25.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>0.0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"<p>65642 rows × 11 columns</p>\n",
"</div>
" ], "text/plain": [ " user_id movie_id rating age Drama gender_M \\\n", "0 48.0 1193.0 4.0 25.0 1.0 1.0 \n", "1 48.0 919.0 4.0 25.0 1.0 1.0 \n", "2 48.0 527.0 5.0 25.0 1.0 1.0 \n", "3 48.0 1721.0 4.0 25.0 1.0 1.0 \n", "4 48.0 150.0 4.0 25.0 1.0 1.0 \n", "... ... ... ... ... ... ... \n", "65637 5878.0 3300.0 2.0 25.0 0.0 0.0 \n", "65638 5878.0 1391.0 1.0 25.0 0.0 0.0 \n", "65639 5878.0 185.0 4.0 25.0 0.0 0.0 \n", "65640 5878.0 2232.0 1.0 25.0 0.0 0.0 \n", "65641 5878.0 426.0 3.0 25.0 0.0 0.0 \n", "\n", " occupation_academic/educator occupation_college/grad student \\\n", "0 0.0 1.0 \n", "1 0.0 1.0 \n", "2 0.0 1.0 \n", "3 0.0 1.0 \n", "4 0.0 1.0 \n", "... ... ... \n", "65637 0.0 0.0 \n", "65638 0.0 0.0 \n", "65639 0.0 0.0 \n", "65640 0.0 0.0 \n", "65641 0.0 0.0 \n", "\n", " occupation_executive/managerial occupation_other \\\n", "0 0.0 0.0 \n", "1 0.0 0.0 \n", "2 0.0 0.0 \n", "3 0.0 0.0 \n", "4 0.0 0.0 \n", "... ... ... \n", "65637 0.0 1.0 \n", "65638 0.0 1.0 \n", "65639 0.0 1.0 \n", "65640 0.0 1.0 \n", "65641 0.0 1.0 \n", "\n", " occupation_technician/engineer \n", "0 0.0 \n", "1 0.0 \n", "2 0.0 \n", "3 0.0 \n", "4 0.0 \n", "... ... \n", "65637 0.0 \n", "65638 0.0 \n", "65639 0.0 \n", "65640 0.0 \n", "65641 0.0 \n", "\n", "[65642 rows x 11 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get the MovieLens data\n", "MovieLens_CEL = _env_getdata_CEL.get_movielens_CEL()\n", "MovieLens_CEL.pop(MovieLens_CEL.columns[0])\n", "\n", "# Remove irrelevant columns\n", "MovieLens_CEL = MovieLens_CEL[MovieLens_CEL.columns.drop(['Comedy','Action', 'Thriller', 'Sci-Fi'])]\n", "MovieLens_CEL" ] }, { "cell_type": "markdown", "id": "5dfc86fd", "metadata": {}, "source": [ "In this selected dataset, we only consider two movie genres for comparison: `Drama` and `Sci-Fi`. That is, the users not watching `Drama` movies are exposed to `Sci-Fi` movies instead." 
] }, { "cell_type": "code", "execution_count": 3, "id": "J__3Ozs7Uxxs", "metadata": { "ExecuteTime": { "end_time": "2023-11-12T12:59:55.356942Z", "start_time": "2023-11-12T12:59:55.353703Z" }, "id": "J__3Ozs7Uxxs" }, "outputs": [], "source": [ "n = len(MovieLens_CEL)\n", "# Column positions of the user covariates S (age, gender_M, occupation dummies)\n", "userinfo_index = np.array([3,5,6,7,8,9,10])\n", "# Covariates S together with the treatment indicator A (`Drama`, column position 4)\n", "SandA = MovieLens_CEL.iloc[:, np.array([3,4,5,6,7,8,9,10])]" ] }, { "cell_type": "code", "execution_count": 4, "id": "h5G8dAwM-PGO", "metadata": { "ExecuteTime": { "end_time": "2023-11-12T12:59:56.313195Z", "start_time": "2023-11-12T12:59:55.358551Z" }, "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 837, "status": "ok", "timestamp": 1676750134359, "user": { "displayName": "Yang Xu", "userId": "12270366590264264299" }, "user_tz": 300 }, "id": "h5G8dAwM-PGO", "outputId": "affb7b39-83cd-4d7e-8572-02cbce6be447" }, "outputs": [ { "data": { "text/plain": [ "GradientBoostingRegressor(max_depth=5)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# S-learner, Step 1: fit a single model for E[R | S, A]\n", "np.random.seed(0)\n", "S_learner = GradientBoostingRegressor(max_depth=5)\n", "# Alternative base learner: S_learner = LinearRegression()\n", "S_learner.fit(SandA, MovieLens_CEL['rating'])" ] }, { "cell_type": "code", "execution_count": 5, "id": "Vqsb5wLTaR0q", "metadata": { "ExecuteTime": { "end_time": "2023-11-12T12:59:56.374355Z", "start_time": "2023-11-12T12:59:56.306220Z" }, "id": "Vqsb5wLTaR0q" }, "outputs": [], "source": [ "# S-learner, Step 2: predict with the treatment column (`Drama`, position 1 in SandA) set to 1 and to 0\n", "SandA_all1 = SandA.copy()\n", "SandA_all0 = SandA.copy()\n", "SandA_all1.iloc[:,1]=np.ones(n)\n", "SandA_all0.iloc[:,1]=np.zeros(n)\n", "\n", "# Plug-in HTE estimate: mu_hat(s,1) - mu_hat(s,0)\n", "HTE_S_learner = S_learner.predict(SandA_all1) - S_learner.predict(SandA_all0)\n" ] }, { "cell_type": "markdown", "id": "FA-F8Jc_T5Lz", "metadata": { "id": "FA-F8Jc_T5Lz" }, "source": [ "Let's look at the estimated HTEs for three example observations (rows 0, 1000, and 5000):" ] }, { "cell_type": "code", "execution_count": 6, "id": 
"GvHnTOxmT5Lz", "metadata": { "ExecuteTime": { "end_time": "2023-11-12T12:59:56.379378Z", "start_time": "2023-11-12T12:59:56.374874Z" }, "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 318, "status": "ok", "timestamp": 1676750150517, "user": { "displayName": "Yang Xu", "userId": "12270366590264264299" }, "user_tz": 300 }, "id": "GvHnTOxmT5Lz", "outputId": "7b0b76fd-f5ac-4ab8-a3c0-188e15484fe7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "S-learner: [0.36103861 0.35479314 0.35916424]\n" ] } ], "source": [ "print(\"S-learner: \",HTE_S_learner[np.array([0,1000,5000])])" ] }, { "cell_type": "code", "execution_count": 7, "id": "651e4f8a", "metadata": { "ExecuteTime": { "end_time": "2023-11-12T12:59:56.382574Z", "start_time": "2023-11-12T12:59:56.378309Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Choosing Drama instead of Sci-Fi is expected to improve the rating of all users by 0.3563 out of 5 points.\n" ] } ], "source": [ "ATE_S_learner = np.sum(HTE_S_learner)/n\n", "print(\"Choosing Drama instead of Sci-Fi is expected to improve the rating of all users by\",round(ATE_S_learner,4), \"out of 5 points.\")" ] }, { "cell_type": "markdown", "id": "mVAZTZYTUKJ6", "metadata": { "id": "mVAZTZYTUKJ6" }, "source": [ "**Conclusion:** As we can see from the estimated ATE by S-learner, people are more inclined to give higher ratings to drama than science fictions. " ] }, { "cell_type": "markdown", "id": "nyirbjS5JdGh", "metadata": { "id": "nyirbjS5JdGh" }, "source": [ "## References\n", "1. Kunzel, S. R., Sekhon, J. S., Bickel, P. J., and Yu, B. (2019). Metalearners for estimating heterogeneous treatment effects using machine learning. 
Proceedings of the National Academy of Sciences 116, 4156–4165.\n" ] }, { "cell_type": "code", "execution_count": 7, "id": "71caea76a1b91036", "metadata": { "ExecuteTime": { "end_time": "2023-11-12T12:59:56.383402Z", "start_time": "2023-11-12T12:59:56.381371Z" } }, "outputs": [], "source": [] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 5 }